Transformer Architecture
前面我们完成了自己训练一个小模型,今天我们结合论文来学习一下Transformer的理论知识~
概述
Transformer 模型于 2017 年在论文《注意力就是你所需要的一切》中首次提出。Transformer 架构旨在训练语言翻译目的模型。然而,OpenAI 的团队发现 transformer 架构是角色预测的关键解决方案。一旦对整个互联网数据进行训练,该模型就有可能理解任何文本的上下文,并连贯地完成任何句子,就像人类一样。
该模型由两部分组成:编码器和解码器。通常,仅编码器体系结构擅长从文本中提取信息以执行分类和回归等任务,而仅解码器模型则专门用于生成文本。例如,专注于文本生成的 GPT 属于仅解码器模型的范畴。
注意
GPT 模型仅使用 transformer 架构的解码器部分。
让我们在训练模型时了解架构的关键思想。
我画了一张图来说明类似 GPT 的仅解码器转换器架构的训练过程:
-
首先,我们需要一系列输入字符作为训练数据。这些输入被转换为矢量嵌入格式。
-
接下来,我们在向量嵌入中添加位置编码,以捕获每个字符在序列中的位置。
-
随后,模型通过一系列计算操作处理这些输入嵌入,最终为给定的输入文本生成可能的下一个字符的概率分布。
-
该模型根据训练数据集中的实际后续特征评估预测结果,并相应地调整概率或“权重”。
-
最后,该模型迭代地完善了这一过程,不断更新其参数以提高未来预测的精度。
让我们深入了解每个步骤的细节。
1 Tokenization
标记化是转换器模型的第一步,该模型:
将输入句子转换为数字表示格式。
标记化是将文本划分为称为标记的较小单元的过程,这些单元可以是单词、子单词、短语或字符。因为将短语分解成更小的部分有助于模型识别文本的底层结构并更有效地处理它。
例如:
Chapter 1: Building Rapport and Capturing
上面这句话可以切成:
Chapter
, , , , , , , , , , , ,1``:``Building``Rap``port``and``Capturing
它被标记为 10 个数字:
如您所见,数字 220 用于表示空格字符。有许多方法可以将字符标记为整数。对于我们的示例数据集,我们将使用 tiktoken 库。
出于演示目的,我将使用一个小型教科书数据集(来自 Hugging Face),其中包含 460k 个字符用于我们的训练。
-
文件大小: 450Kb
-
词汇量:3,771(表示唯一单词/子单词)
我们的训练数据包含 3,771 个不同字符的词汇量。用于标记我们的教科书数据集的最大数量是 ,它被映射到一个字符。100069``Clar
一旦我们有了标记化映射,我们就可以为数据集中每个字符找到相应的整数索引。我们将利用这些分配的整数索引作为标记,而不是在与模型交互时使用整个单词。
2 Word Embeddings
首先,让我们构建一个包含词汇表中所有字符的查找表。从本质上讲,该表由一个填充了随机初始化数字的矩阵组成。
给定我们拥有的最大标记数是 ,并考虑维度为 64(原始论文使用 512 维,表示为 d_model),生成的查找表变为 100,069 × 64 矩阵,这称为标记嵌入查找表。表示如下:100069
1Token Embedding Look-Up Table: 2 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 30 0.625765 0.025510 0.954514 0.064349 -0.502401 -0.202555 -1.567081 -1.097956 0.235958 -0.239778 ... 0.420812 0.277596 0.778898 1.533269 1.609736 -0.403228 -0.274928 1.473840 0.068826 1.332708 41 -0.497006 0.465756 -0.257259 -1.067259 0.835319 -1.956048 -0.800265 -0.504499 -1.426664 0.905942 ... 0.008287 -0.252325 -0.657626 0.318449 -0.549586 -1.464924 -0.557690 -0.693927 -0.325247 1.243933 52 1.347121 1.690980 -0.124446 -1.682366 1.134614 -0.082384 0.289316 0.835773 0.306655 -0.747233 ... 0.543340 -0.843840 -0.687481 2.138219 0.511412 1.219090 0.097527 -0.978587 -0.432050 -1.493750 63 1.078523 -0.614952 -0.458853 0.567482 0.095883 -1.569957 0.373957 -0.142067 -1.242306 -0.961821 ... -0.882441 0.638720 1.119174 -1.907924 -0.527563 1.080655 -2.215207 0.203201 -1.115814 -1.258691 74 0.814849 -0.064297 1.423653 0.261726 -0.133177 0.211893 1.449790 3.055426 -1.783010 -0.832339 ... 0.665415 0.723436 -1.318454 0.785860 -1.150111 1.313207 -0.334949 0.149743 1.306531 -0.046524 8... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... 9100064 -0.898191 -1.906910 -0.906910 1.838532 2.121814 -1.654444 0.082778 0.064536 0.345121 0.262247 ... 0.438956 0.163314 0.491996 1.721039 -0.124316 1.228242 0.368963 1.058280 0.406413 -0.326223 10100065 1.354992 -1.203096 -2.184551 -1.745679 -0.005853 -0.860506 1.010784 0.355051 -1.489120 -1.936192 ... 1.354665 -1.338872 -0.263905 0.284906 0.202743 -0.487176 -0.421959 0.490739 -1.056457 2.636806 11100066 -0.436116 0.450023 -1.381522 0.625508 0.415576 0.628877 -0.595811 -1.074244 -1.512645 -2.027422 ... 0.436522 0.068974 1.305852 0.005790 -0.583766 -0.797004 0.144952 -0.279772 1.522029 -0.629672 12100067 0.147102 0.578953 -0.668165 -0.011443 0.236621 0.348374 -0.706088 1.368070 -1.428709 -0.620189 ... 1.130942 -0.739860 -1.546209 -1.475937 -0.145684 -1.744829 0.637790 -1.064455 1.290440 -1.110520 13100068 0.415268 -0.345575 0.441546 -0.579085 1.110969 -1.303691 0.143943 -0.714082 -1.426512 1.646982 ... -2.502535 1.409418 0.159812 -0.911323 0.856282 -0.404213 -0.012741 1.333426 0.372255 0.722526 14 15[100,069 rows x 64 columns]
其中每行代表一个字符(按其标记编号索引),每列代表一个维度。
现在,您可以将“维度”视为角色的特征或方面。在我们的例子中,我们指定了 64 个维度,这意味着我们将能够以 64 种不同的方式理解一个角色的文本含义,例如将其分类为名词、动词、形容词等。
假设,现在我们有一个 16 context_length的训练输入示例,即:
" . By mastering the art of identifying underlying motivations and desires, we equip ourselves with
"
现在,我们通过使用其整数索引来查找嵌入表,从而检索每个标记化字符(或单词)的嵌入向量。因此,我们得到了它们各自的输入嵌入:
[ 627, 1383, 88861, 279, 1989, 315, 25607, 16940, 65931, 323, 32097, 11, 584, 26458, 13520, 449]
在变压器架构中,多个输入序列同时并行处理,通常称为多批处理。让我们将batch_size设置为 4。因此,我们将一次处理四个随机选择的句子作为我们的输入。
1Input Sequence Batch: 2 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 30 627 1383 88861 279 1989 315 25607 16940 65931 323 32097 11 584 26458 13520 449 41 15749 311 9615 3619 872 6444 6 3966 11 10742 11 323 32097 13 3296 22815 52 13189 315 1701 5557 304 6763 374 88861 7528 10758 7526 13 4314 7526 2997 2613 63 323 6376 2867 26470 1603 16661 264 49148 627 18 13 81745 48023 75311 7246 66044 7 8[4 rows x 16 columns]
每行代表一个句子;每列是该句子从第 0 位到第 15 位的字符。
结果,我们现在有一个矩阵,表示 4 批 16 个字符的输入。该矩阵的形状为 (batch_size, context_length) = [4, 16]。
回顾一下,我们将输入嵌入查找表定义为大小为 100,069 × 64 的矩阵。下一步是获取我们的输入序列矩阵并将其映射到这个嵌入矩阵上,以获得我们的输入嵌入。
在这里,我们将重点分解输入序列矩阵的每一行,从第一行开始。首先,我们将此开始行从其原始尺寸 (1, context_length) = [1, 16] 重塑为 (context_length, 1) = [16, 1] 的新格式。随后,我们将这个重组后的行覆盖在我们之前建立的嵌入矩阵大小 (vocab_size, d_model) = [100069, 64] 上,从而将匹配的嵌入向量替换为给定上下文窗口中存在的每个字符。生成的输出是形状为 (context_length, d_model) = [16, 64] 的矩阵。
输入序列批处理的第一行:
1Input Embedding: 2 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 30 1.051807 -0.704369 -0.913199 -1.151564 0.582201 -0.898582 0.984299 -0.075260 -0.004821 -0.743642 ... 1.151378 0.119595 0.601200 -0.940352 0.289960 0.579749 0.428623 0.263096 -0.773865 -0.734220 41 -0.293959 -1.278850 -0.050731 0.862562 0.200148 -1.732625 0.374076 -1.128507 0.281203 -1.073113 ... -0.062417 -0.440599 0.800283 0.783043 1.602350 -0.676059 -0.246531 1.005652 -1.018667 0.604092 52 -0.292196 0.109248 -0.131576 -0.700536 0.326451 -1.885801 -0.150834 0.348330 -0.777281 0.986769 ... 0.382480 1.315575 -0.144037 1.280103 1.112829 0.438884 -0.275823 -2.226698 0.108984 0.701881 63 0.427942 0.878749 -0.176951 0.548772 0.226408 -0.070323 -1.865235 1.473364 1.032885 0.696173 ... 1.270187 1.028823 -0.872329 -0.147387 -0.083287 0.142618 -0.375903 -0.101887 0.989520 -0.062560 74 -1.064934 -0.131570 0.514266 -0.759037 0.294044 0.957125 0.976445 -1.477583 -1.376966 -1.171344 ... 0.231112 1.278687 0.254688 0.516287 0.621753 0.219179 1.345463 -0.927867 0.510172 0.656851 85 2.514588 -1.001251 0.391298 -0.845712 0.046932 -0.036732 1.396451 0.934358 -0.876228 -0.024440 ... 0.089804 0.646096 -0.206935 0.187104 -1.288239 -1.068143 0.696718 -0.373597 -0.334495 -0.462218 96 0.498423 -0.349237 -1.061968 -0.093099 1.374657 -0.512061 -1.238927 -1.342982 -1.611635 2.071445 ... 0.025505 0.638072 0.104059 -0.600942 -0.367796 -0.472189 0.843934 0.706170 -1.676522 -0.266379 107 1.684027 -0.651413 -0.768050 0.599159 -0.381595 0.928799 2.188572 1.579998 -0.122685 -1.026440 ... -0.313672 1.276962 -1.142109 -0.145139 1.207923 -0.058557 -0.352806 1.506868 -2.296642 1.378678 118 -0.041210 -0.834533 -1.243622 -0.675754 -1.776586 0.038765 -2.713090 2.423366 -1.711815 0.621387 ... -1.063758 1.525688 -1.762023 0.161098 0.026806 0.462347 0.732975 0.479750 0.942445 -1.050575 129 0.708754 1.058510 0.297560 0.210548 0.460551 1.016141 2.554897 0.254032 0.935956 -0.250423 ... -0.552835 0.084124 0.437348 0.596228 0.512168 0.289721 -0.028321 -0.932675 -0.411235 1.035754 1310 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125 1411 -0.753828 0.098016 -0.945322 0.708373 -1.493744 0.394732 0.075629 -0.049392 -1.005564 0.356353 ... 2.452891 -0.233571 0.398788 -1.597272 -1.919085 -0.405561 -0.266644 1.237022 1.079494 -2.292414 1512 -0.611864 0.006810 1.989711 -0.446170 -0.670108 0.045619 -0.092834 1.226774 -1.407549 -0.096695 ... 1.181310 -0.407162 -0.086341 -0.530628 0.042921 1.369478 0.823999 -0.312957 0.591755 0.516314 1613 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125 1714 -1.174090 0.096075 -0.749195 0.395859 -0.622460 -1.291126 0.094431 0.680156 -0.480742 0.709318 ... 0.786663 0.237733 1.513797 0.296696 0.069533 -0.236719 1.098030 -0.442940 -0.583177 1.151497 1815 0.401740 -0.529587 3.016675 -1.134723 -0.256546 -0.219896 0.637936 2.000511 -0.418684 -0.242720 ... -0.442287 -1.519394 -1.007496 -0.517480 0.307449 -0.316039 -0.880636 -1.424680 -1.901644 1.968463 19 20[16 rows x 64 columns]
矩阵显示映射后的四行之一
我们对其余的 3 行执行相同的操作,最终我们有 4 组 x [16 行 x 64 列]。
这会导致形状为 (batch_size, context_length, d_model) = [4, 16, 64] 的输入嵌入矩阵。
从本质上讲,为每个单词提供唯一的嵌入允许模型适应语言的变化并管理具有多种含义或形式的单词。
让我们继续前进,理解我们的输入嵌入矩阵作为我们模型的预期输入格式,即使我们还没有完全掌握起作用的基本数学原理。
3 Positional Encoding
在我看来,位置编码是变压器架构中最具挑战性的概念。
总结一下位置编码解决了什么问题:
-
我们希望每个单词都带有一些关于它在句子中的位置的信息。
-
我们希望模型将看起来彼此接近的单词视为“接近”,将距离较远的单词视为“遥远”。
-
我们希望位置编码表示模型可以学习的模式。
位置编码描述序列中实体的位置或位置,以便为每个位置分配唯一的表示形式。
位置编码是另一个数字向量,它被添加到每个标记化字符的输入嵌入中。位置编码是正弦波和余弦波,其频率根据标记化字符的位置而变化。
在原始论文中,引入的位置编码计算方法是:
1PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) 2PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中 是位置,从 0 到 d_model/2。 是我们在训练模型时定义的模型维度(在我们的例子中是 64,在原始论文中他们使用 512)。pos``i``d_model
事实上,这个位置编码矩阵只创建一次,并重复用于每个输入序列。
让我们看一下位置编码矩阵:
让我们多谈谈位置编码技巧。
1Position Embedding Look-Up Table: 2 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 30 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 ... 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 41 0.841471 0.540302 0.681561 0.731761 0.533168 0.846009 0.409309 0.912396 0.310984 0.950415 ... 0.000422 1.000000 0.000316 1.000000 0.000237 1.000000 0.000178 1.000000 0.000133 1.000000 52 0.909297 -0.416147 0.997480 0.070948 0.902131 0.431463 0.746904 0.664932 0.591127 0.806578 ... 0.000843 1.000000 0.000632 1.000000 0.000474 1.000000 0.000356 1.000000 0.000267 1.000000 63 0.141120 -0.989992 0.778273 -0.627927 0.993253 -0.115966 0.953635 0.300967 0.812649 0.582754 ... 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000 0.000400 1.000000 74 -0.756802 -0.653644 0.141539 -0.989933 0.778472 -0.627680 0.993281 -0.115730 0.953581 0.301137 ... 0.001687 0.999999 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000 85 -0.958924 0.283662 -0.571127 -0.820862 0.323935 -0.946079 0.858896 -0.512150 0.999947 -0.010342 ... 0.002108 0.999998 0.001581 0.999999 0.001186 0.999999 0.000889 1.000000 0.000667 1.000000 96 -0.279415 0.960170 -0.977396 -0.211416 -0.230368 -0.973104 0.574026 -0.818837 0.947148 -0.320796 ... 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999 0.000800 1.000000 107 0.656987 0.753902 -0.859313 0.511449 -0.713721 -0.700430 0.188581 -0.982058 0.800422 -0.599437 ... 0.002952 0.999996 0.002214 0.999998 0.001660 0.999999 0.001245 0.999999 0.000933 1.000000 118 0.989358 -0.145500 -0.280228 0.959933 -0.977262 -0.212036 -0.229904 -0.973213 0.574318 -0.818632 ... 0.003374 0.999994 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999 129 0.412118 -0.911130 0.449194 0.893434 -0.939824 0.341660 -0.608108 -0.793854 0.291259 -0.956644 ... 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999 0.001200 0.999999 1310 -0.544021 -0.839072 0.937633 0.347628 -0.612937 0.790132 -0.879767 -0.475405 -0.020684 -0.999786 ... 0.004217 0.999991 0.003162 0.999995 0.002371 0.999997 0.001778 0.999998 0.001334 0.999999 1411 -0.999990 0.004426 0.923052 -0.384674 -0.097276 0.995257 -0.997283 -0.073661 -0.330575 -0.943780 ... 0.004639 0.999989 0.003478 0.999994 0.002609 0.999997 0.001956 0.999998 0.001467 0.999999 1512 -0.536573 0.843854 0.413275 -0.910606 0.448343 0.893862 -0.940067 0.340989 -0.607683 -0.794179 ... 0.005060 0.999987 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999 1613 0.420167 0.907447 -0.318216 -0.948018 0.855881 0.517173 -0.718144 0.695895 -0.824528 -0.565821 ... 0.005482 0.999985 0.004111 0.999992 0.003083 0.999995 0.002312 0.999997 0.001734 0.999998 1714 0.990607 0.136737 -0.878990 -0.476839 0.999823 -0.018796 -0.370395 0.928874 -0.959605 -0.281349 ... 0.005904 0.999983 0.004427 0.999990 0.003320 0.999995 0.002490 0.999997 0.001867 0.999998 1815 0.650288 -0.759688 -0.968206 0.250154 0.835838 -0.548975 0.042249 0.999107 -0.999519 0.031022 ... 0.006325 0.999980 0.004743 0.999989 0.003557 0.999994 0.002667 0.999996 0.002000 0.999998 19 20[16 rows x 64 columns]
据我了解,位置值是根据它们在序列中的相对位置建立的。此外,由于每个输入句子的上下文长度一致,它使我们能够在各种输入中回收相同的位置编码。因此,必须谨慎地创建序列号,以防止过大的幅度对输入嵌入产生负面影响,确保相邻位置表现出微小的差异,而远处的位置显示出它们之间的较大差异。
使用正弦和余弦向量的组合,该模型可以看到独立于词嵌入的位置编码向量,而不会混淆输入嵌入(语义)信息。很难想象这在神经元网络内部是如何工作的,但它是有效的。
我们可以可视化我们的位置嵌入数字并查看模式。
每条垂直线是我们从 0 到 64 的维度;每行代表一个字符。 这些值介于 -1 和 1 之间,因为它们来自正弦和余弦函数。颜色越深表示值越接近 -1,颜色越亮表示值越接近 1。绿色表示介于两者之间的值。
让我们回到我们的位置编码矩阵,正如你所看到的,这个位置编码表与输入嵌入表 [4, 16, 64] 中的每个批处理具有相同的形状,它们都是 (context_length, d_model) = [16, 64]。
由于两个具有相同形状的矩阵可以相加,因此我们可以将位置信息添加到每个输入嵌入行中,以获得最终输入嵌入矩阵。
1batch 0: 2 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 30 1.051807 0.295631 -0.913199 -0.151564 0.582201 0.101418 0.984299 0.924740 -0.004821 0.256358 ... 1.151378 1.119595 0.601200 0.059648 0.289960 1.579749 0.428623 1.263096 -0.773865 0.265780 41 0.547512 -0.738548 0.630830 1.594323 0.733316 -0.886616 0.783385 -0.216111 0.592187 -0.122698 ... -0.061995 0.559401 0.800599 1.783043 1.602587 0.323941 -0.246353 2.005651 -1.018534 1.604092 52 0.617101 -0.306899 0.865904 -0.629588 1.228581 -1.454339 0.596070 1.013263 -0.186154 1.793348 ... 0.383324 2.315575 -0.143404 2.280102 1.113303 1.438884 -0.275467 -1.226698 0.109251 1.701881 63 0.569062 -0.111243 0.601322 -0.079154 1.219661 -0.186289 -0.911600 1.774332 1.845533 1.278927 ... 1.271452 2.028822 -0.871380 0.852612 -0.082575 1.142617 -0.375369 0.898113 0.989920 0.937440 74 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851 85 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782 96 0.219007 0.610934 -2.039364 -0.304516 1.144289 -1.485164 -0.664902 -2.161820 -0.664487 1.750649 ... 0.028036 1.638068 0.105957 0.399056 -0.366373 0.527810 0.845001 1.706170 -1.675722 0.733621 107 2.341013 0.102489 -1.627363 1.110608 -1.095316 0.228369 2.377153 0.597940 0.677737 -1.625878 ... -0.310720 2.276958 -1.139895 0.854859 1.209583 0.941441 -0.351562 2.506867 -2.295708 2.378678 118 0.948148 -0.980033 -1.523850 0.284180 -2.753848 -0.173272 -2.942995 1.450153 -1.137498 -0.197246 ... -1.060385 2.525683 -1.759494 1.161095 0.028703 1.462346 0.734397 1.479749 0.943511 -0.050575 129 1.120872 0.147380 0.746753 1.103982 -0.479273 1.357801 1.946789 -0.539822 1.227215 -1.207067 ... -0.549040 1.084117 0.440194 1.596224 0.514303 1.289719 -0.026721 0.067324 -0.410035 2.035753 1310 -1.128574 0.556604 1.664986 0.988980 0.080544 -1.323841 -1.665967 -0.803163 1.258105 -1.155904 ... 1.208804 0.868336 -0.592132 0.566557 -0.861313 4.272244 0.103369 1.619057 -0.980840 -0.174126 1411 -1.753818 0.102441 -0.022270 0.323699 -1.591020 1.389990 -0.921654 -0.123053 -1.336139 -0.587427 ... 2.457530 0.766419 0.402266 -0.597278 -1.916476 0.594436 -0.264688 2.237020 1.080961 -1.292415 1512 -1.148437 0.850664 2.402985 -1.356776 -0.221765 0.939481 -1.032902 1.567763 -2.015232 -0.890874 ... 1.186370 0.592825 -0.082546 0.469365 0.045767 2.369474 0.826133 0.687041 0.593355 1.516313 1613 -0.164386 2.303123 0.409138 -0.306666 1.549362 -1.596800 -1.504343 0.368137 0.454260 -0.721938 ... 1.210069 0.868330 -0.591184 0.566554 -0.860601 4.272243 0.103903 1.619056 -0.980440 -0.174127 1714 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 1815 1.052028 -1.289275 2.048469 -0.884570 0.579293 -0.768871 0.680185 2.999618 -1.418203 -0.211697 ... -0.435962 -0.519414 -1.002752 0.482508 0.311006 0.683955 -0.877969 -0.424683 -1.899643 2.968462 19 20[16 rows x 64 columns] 21 22batch 1: 23 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 240 -0.264236 0.965681 1.909974 -0.338721 -0.554196 0.254583 -0.576111 1.766522 -0.652587 0.455450 ... -1.016426 0.458762 -0.513290 0.618411 0.877229 2.526591 0.614551 0.662366 -1.246907 1.128066 251 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580 262 -1.348308 -1.080221 1.753394 0.156193 0.440652 1.015287 -0.790644 1.215537 2.037030 0.476560 ... 0.296941 1.100837 -0.153194 1.329375 -0.188958 1.229344 -1.301919 0.938138 -0.860689 -0.860137 273 0.601103 -0.156419 0.850114 -0.324190 -0.311584 -2.232454 -0.903112 0.242687 0.801908 2.502464 ... -0.397007 1.150545 -0.473907 0.318961 -1.970126 1.967961 -0.186831 0.131873 0.947445 -0.281573 284 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851 295 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782 306 0.599841 0.943214 -1.397184 -0.607349 -0.333995 -1.222589 -0.731189 -0.997706 1.848611 0.254238 ... 0.340986 1.383113 1.674592 2.229903 -0.157415 0.362868 -0.493762 1.904136 0.027903 1.196017 317 0.072234 1.386670 -0.985962 -1.184486 0.958293 -0.295773 -1.529277 -0.727844 1.510503 1.268154 ... -0.356459 0.382331 0.138104 -0.360916 -0.638448 1.305404 -0.756442 0.299150 0.154600 -0.466154 328 -0.008645 -1.066763 -0.716555 2.148885 -0.709739 -0.137266 0.385401 0.699139 1.907906 -2.357567 ... 0.490190 -1.215412 1.216459 0.659227 -0.282908 -0.912266 0.595569 1.210701 0.737407 0.801672 339 -0.006332 -0.949928 0.192689 3.158421 -1.292153 -0.830248 0.966141 -2.056514 0.042364 1.485927 ... 0.480763 -0.318554 0.005837 3.031636 -0.448117 1.059403 0.598106 0.871427 0.327321 1.090921 3410 -1.152681 -0.710162 -0.456591 -0.468090 -0.292566 0.747535 -0.149907 -0.395523 0.170872 -2.372754 ... -1.267461 0.043283 -0.114980 1.083042 -0.288776 1.442318 0.775591 0.728716 -0.576776 -0.727257 3511 -0.955986 -0.277475 0.946888 -0.242687 1.257744 0.369994 0.460073 0.728078 -0.165204 -0.761762 ... -0.307983 2.078995 -1.067792 1.805637 0.608968 1.722982 -0.371174 -0.603182 0.285387 1.112932 3612 -0.844347 0.883224 1.222388 -0.811387 -0.593557 0.157268 -0.650315 1.289236 -1.472027 -0.447092 ... -0.536433 2.465097 -0.822905 1.272786 0.703664 2.687270 -0.924388 0.596134 -0.367138 0.812242 3713 0.776470 1.549248 -0.239693 0.133783 0.767255 1.996130 -0.436228 -0.327975 -0.650743 0.507769 ... -0.821793 1.387792 -1.052105 2.123603 1.421092 2.066746 -0.747766 0.627081 -1.749071 -0.679443 3814 1.277579 0.653945 0.045632 -0.409790 0.829708 0.249433 -0.682051 0.601958 -1.932014 -2.077397 ... 0.160611 1.037856 0.656832 0.992817 -0.684056 1.031199 -0.180866 4.579140 -1.123555 0.181580 3915 0.356328 -2.038538 -1.018938 1.112716 1.035987 -2.281600 0.416325 -0.129400 -0.718316 -1.042091 ... -0.056092 0.559381 0.805026 1.783032 1.605907 0.323934 -0.243863 2.005648 -1.016667 1.604090 40 41[16 rows x 64 columns] 42 43batch 2: 44 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 450 0.645854 1.291073 -1.588931 1.814376 -0.185270 0.846816 -1.686862 0.982995 -0.973108 1.297203 ... 0.852600 1.533231 0.692729 2.437029 -0.178137 0.493413 0.597484 1.909155 1.257821 2.644325 461 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580 472 3.298391 -0.363908 0.376535 -0.276692 1.262433 -0.595659 1.694541 0.542514 -0.464756 0.368460 ... -0.169474 1.420809 0.304488 1.689731 -1.128037 -0.024476 -1.356808 2.160992 -2.110703 -0.472404 483 0.626955 -2.988524 0.915578 1.123503 0.635983 0.078006 0.466728 -0.930765 2.189286 1.505499 ... 2.496649 1.691578 0.642664 2.089205 1.926187 1.185045 -0.969952 0.666007 -0.030641 0.667574 494 0.396447 -2.116415 0.384262 -1.632779 0.859029 -0.726599 2.121946 -1.314046 0.744388 -0.227106 ... -1.937352 2.378620 0.029220 1.215336 -0.405487 -0.834419 -1.219825 0.000676 -0.821293 0.340797 505 -2.133014 0.379737 -1.320323 -0.425003 -0.298524 -2.237205 0.953327 0.168006 0.519205 0.698976 ... 0.788771 1.237731 1.515378 1.296695 0.070718 0.763281 1.098920 0.557059 -0.582510 2.151497 516 -0.390918 0.634039 -1.350461 0.032129 0.106428 0.370410 1.292387 0.986316 -0.095396 0.555067 ... -1.792372 -0.357599 0.912276 0.088746 0.866950 0.927208 -0.381643 2.532119 0.464615 -1.044299 527 -0.407947 0.622332 -0.345048 -0.247587 -0.419677 0.256695 1.165026 -2.459640 -0.576545 -1.770781 ... 0.234064 2.278682 0.256901 1.516285 0.623413 1.219177 1.346708 0.072133 0.511105 1.656851 538 3.503946 -1.146751 0.111070 0.114221 -0.930330 -0.248769 1.166547 -0.038856 -0.301910 -0.843072 ... 0.093177 1.646091 -0.204405 1.187101 -1.286342 -0.068145 0.698141 0.626402 -0.333428 0.537781 549 -1.946920 -0.443788 0.560103 3.584257 -0.134643 -1.538940 -1.059084 -0.128679 2.503847 -2.244587 ... -0.643552 1.608934 -0.488734 -0.291253 1.633294 -0.018763 0.696360 -0.657761 0.692395 1.741288 5510 0.376520 0.583786 -0.705047 0.855548 0.471473 0.687240 -0.605646 0.463047 1.619052 -1.894214 ... -0.688652 1.974150 -1.399412 2.567682 -0.050040 1.782055 -0.297912 2.366196 -1.888527 0.635260 5611 -0.109256 -1.394054 0.565499 -0.093785 -1.803309 0.662382 -1.528203 1.644028 -0.569133 0.438101 ... 0.741877 1.988214 2.547823 1.995465 0.450234 3.174424 0.446768 0.860424 2.139130 1.537579 5712 -1.553993 -0.983421 0.392842 -1.473186 1.530387 1.894017 -0.732786 -1.601045 -0.740344 0.245303 ... -0.328828 3.013883 1.178296 1.263333 0.284824 0.791874 2.402131 -0.231270 -1.025411 0.178748 5813 -0.757965 1.771306 0.805440 -0.509121 1.212250 0.388750 -0.606959 2.352489 -2.445346 -0.103223 ... 0.425556 1.783019 0.698336 1.871530 2.314023 0.424368 -1.002745 0.983784 -0.090133 0.905337 5914 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 6015 -0.151101 -0.257150 -0.478131 -1.170082 1.318685 -0.188166 0.146375 2.895475 -0.918949 -0.305261 ... 1.623350 1.656103 -0.600456 1.039260 -1.944202 0.894911 1.409396 1.722673 -0.172070 2.265543 61 62[16 rows x 64 columns] 63 64batch 3: 65 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 660 0.377847 -0.380613 1.958640 0.224087 -0.420293 0.915635 -1.077748 1.255988 -0.223147 0.977568 ... -1.290532 1.460963 1.365088 -2.037483 -2.213841 1.039091 -2.129649 0.108403 -0.356996 2.239356 671 0.527961 0.342787 0.096746 0.885016 0.706699 2.873656 0.139732 0.497379 -0.009022 -0.147825 ... -0.409913 0.785146 -0.138166 2.041000 0.277500 1.578947 -1.535113 0.912230 -0.312735 0.540365 682 1.054965 -0.134411 2.155045 -0.188724 0.651576 -0.265663 -0.777263 0.571080 1.508661 1.021718 ... 0.762458 2.297400 -0.624743 -0.979212 2.024008 1.295633 0.208825 0.953138 -2.962624 1.586901 693 -1.032970 -0.893918 0.029077 -0.232068 0.370793 -1.407092 1.048066 0.981123 0.331907 1.292072 ... 0.787928 1.237732 1.514746 1.296695 0.070244 0.763281 1.098564 0.557060 -0.582777 2.151497 704 -0.980037 -1.014605 1.875135 -2.459635 0.486067 -0.941092 1.205490 1.248531 1.801383 0.576983 ... 0.192097 1.784109 -0.201023 0.405095 0.982041 1.927637 0.008535 1.063376 -1.439787 2.967185 715 -0.369996 -1.151058 -0.126222 0.768431 0.107524 -0.481010 2.056029 -0.872815 1.522675 -0.440916 ... 0.246007 -1.032684 0.572565 0.944744 0.790383 -0.034063 -1.704374 -0.053319 1.739537 2.381506 726 -0.555136 -0.284736 -0.162689 -1.542923 -1.619371 -2.014224 0.957231 -0.338164 1.353500 -2.048436 ... 0.180549 -0.598603 0.427175 1.845072 0.924364 -0.013093 -0.054108 -0.082885 -0.719218 0.960552 737 0.548834 1.130444 1.207497 0.565839 -1.814344 -0.111523 0.480270 -1.741823 1.451116 -0.977640 ... 1.692325 -0.708754 -0.747591 1.373189 -0.224415 -0.074035 -0.323435 2.001849 -1.102584 1.644658 748 0.117209 -0.905490 0.272336 0.994848 0.648951 0.354459 -0.731171 -1.641071 -0.966286 -0.837498 ... 0.294006 1.008774 1.376944 2.969555 0.997452 2.076708 0.631358 1.080600 0.075384 1.819302 759 0.557786 -0.629395 1.606758 0.633762 -1.190379 -0.355466 -2.132275 -0.887707 1.208793 -0.741505 ... 0.765410 2.297393 -0.622529 -0.979216 2.025668 1.295631 0.210070 0.953136 -2.961691 1.586900 7610 1.107697 -2.050459 1.399869 1.271179 -1.391529 1.103020 -0.910370 -0.398901 -0.803458 -2.081302 ... 1.462017 -0.115730 0.171052 0.594118 0.514388 1.593223 0.064085 -0.029184 -0.044621 1.206415 7711 -1.771933 0.469475 0.961730 0.002798 1.386089 0.250342 -0.062900 -0.569053 -2.149857 -0.519952 ... -0.725692 -0.727693 -0.178683 1.675822 -0.401712 1.109331 0.980627 -0.357667 -0.484853 0.208340 7812 -1.518213 1.899549 -0.320427 -0.929415 -0.701020 0.727833 -2.764498 0.612756 0.041370 -1.599998 ... -0.136314 1.068995 0.635501 0.765369 0.270007 0.319588 -0.652992 1.322658 1.724227 2.343042 7913 0.094923 0.575470 -0.852224 -2.098593 0.998579 0.347285 -0.467688 0.773722 -1.664829 -0.412623 ... -1.274262 0.454381 -1.142107 1.853844 -1.912537 0.544311 0.667555 -1.187468 1.291108 2.275956 8014 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 8115 2.053710 -2.769740 -0.148796 0.983717 -0.038190 -0.655360 1.826909 -0.332533 -1.036128 -1.001430 ... 0.674310 0.695848 -0.181635 1.051397 -0.884897 1.590696 -1.375117 0.596254 -0.651398 0.797715 82 83[16 rows x 64 columns]
最终输入嵌入将馈送到 Transformer 解码器模块进行训练。
这个最终结果矩阵称为位置输入嵌入,其形状为 (batch_size, context_length, d_model) = [4, 16, 64]。
这都是关于位置编码的。
但是为什么要使用正弦和余弦函数来编码位置呢?为什么不只是一个随机数?为什么这两个数字相加可以同时包含其含义和位置信息?好吧,一开始我也有同样的疑惑。然而,我发现没有必要完全理解它的基础数学来训练一个模型。因此,如果您渴望详细解释,请参阅单独的部分或观看我的视频剪辑。我将就此结束,然后继续下一步。
到目前为止,我们已经介绍了模型的输入编码和位置编码部分。让我们转到变压器块。
4 Transformer Block
Transformer 模块是由三层组成的堆栈:一个屏蔽的多头注意力机制、两个归一化层和一个前馈网络。
蒙面的多头注意力是一组自我注意,每个自我注意都称为一个头。因此,让我们先来看看自我注意力机制。
4.1 Multi-Head Attention Overview
变形金刚的力量来自一种叫做自我关注的东西。通过自我关注,模型密切关注输入中最关键的部分。每个部分都称为一个头部。
这是磁头的工作原理:磁头通过三个独特的层(称为查询 (Q)、键 (K) 和值 (V) 处理输入来工作。它首先比较 Q 和 K,调整结果,然后使用这些比较创建一组分数,显示重要内容。然后使用这些分数来权衡 V 中的信息,从而更加关注重要部分。头部的学习来自于随着时间的推移调整这些 Q、K 和 V 层中的设置。
多头注意力只是由几个单独的头堆叠在一起组成。所有磁头都接收到完全相同的输入,尽管它们在计算过程中使用自己特定的权重集。处理输入后,所有磁头的输出被连接起来,然后通过线性层。
下图提供了头部内过程的可视化表示,以及多头注意力模块中的详细信息。
为了进行证明计算,让我们从原始论文“注意力是你所需要的”中引入公式:
从公式中,我们首先需要三个矩阵:Q(查询)、K(键)和 V(值)。要计算注意力分数,我们需要执行以下步骤:
-
将 Q 乘以 K 转置(表示为 K^T)
-
除以 K 维数的平方根
-
应用 SoftMax 函数
-
乘以 V
我们将一一介绍。
4.2 Prepare Q,K,V
计算注意力的第一步是获取 Q、K 和 V 矩阵,分别表示查询、键和值。这三个值将用于我们的注意力层来计算注意力概率(权重)。这些是通过将上一步中的位置输入嵌入矩阵(表示为 X)应用于标记为 Wq、Wk 和 Wv 的三个不同的线性层来确定的(所有值都是随机分配的,首先可学习)。然后将每个线性层的输出拆分为多个磁头,表示为 num_heads,这里我们选择 4 个磁头。
Wq、Wk、Wv 是三个矩阵,维度为 (d_model, d_model) = [64, 64]。所有值都是随机分配的。这在神经网络中称为线性层或可训练参数。可训练参数是模型在训练期间将学习和自我更新的值。
为了获得我们的 Q,K,V 值,我们在输入嵌入矩阵 X 和三个矩阵 Wq、Wk、Wv 中的每一个之间进行矩阵乘法(再次,它们的初始值是随机分配的)。
-
Q = X*Wq
-
K = X*周
-
V = X*Wv
上述函数的计算(矩阵乘法)逻辑:
X 的形状为 (batch_size, context_length, d_model) = [4, 16, 64],我们将其分解为 4 个形状为 [16, 64] 的子矩阵。而 Wq、Wk、Wv 的形状为 (d_model, d_model) = [64, 64]。我们可以对 4 个 X 的子矩阵中的每一个进行矩阵乘法,以 Wq、Wk、Wv 为单位。
如果回想一下线性代数,则只有当第一个矩阵中的列数等于第二个矩阵中的行数时,才有可能对两个矩阵进行乘法。在我们的例子中,X 中的列数是 64,Wq、Wk、Wv 中的行数也是 64。因此,乘法是可能的。
矩阵乘法得到 4 个形状为 [16, 64] 的子矩阵的形状,可以组合表示为 (batch_size, context_length, d_model) = [4, 16, 64]。
现在,我们的 Q、K、V 矩阵的形状为 (batch_size, context_length, d_model) = [4, 16, 64]。接下来,我们需要将它们拆分为多个头。这就是为什么变压器架构将其命名为多头注意力的原因。
劈头只是意味着在d_model的 64 个维度中,我们将它们切割成多个头部,每个头部包含一定数量的维度。每个头部都将能够学习输入的某些模式或语义。
假设我们将num_heads也设置为 4。这意味着我们将 Q、K、V 形状为 [4, 16, 64] 的矩阵拆分为多个子矩阵。
实际的拆分是通过将 64 的最后一个维度重塑为 16 的 4 个子维度来完成的。
每个 Q、K、V 矩阵从形状 [4, 16, 64] 转换为 [4, 16, 4, 16]。最后两个维度是头部。换句话说,它从以下转变而来:
[batch_size、context_length、d_model]
自:
[batch_size、context_length、num_heads、head_size]
要理解具有相同形状的 Q、K 和 V 矩阵 [4, 16, 4, 16],请考虑以下观点:
在管道中,有四个批次。每批由 16 个代币(单词)组成。对于每个标记,有 4 个头,每个头编码 16 个维度的语义信息。
4.3 Calculate Q,K Attention
现在我们已经有了 Q、K 和 V 这三个矩阵,让我们开始逐步计算单头注意力。
从变压器图中,Q 和 K 矩阵首先相乘。
现在,如果我们丢弃 Q 和 K 矩阵中的batch_size,只保留最后三个维度,现在 Q = K = V = [context_length, num_heads, head_size] = [16, 4, 16]。
我们需要在前两个维度上再做一个转置,使它们的形状为 Q = K = V = [num_heads, context_length, head_size] = [4 ,16, 16]。这是因为我们需要在最后两个维度上进行矩阵乘法运算。
Q * K^T = [4, 16, 16] * [4, 16, 16] = [4, 16, 16]
我们为什么要这样做?此处的转置是为了促进不同上下文之间的矩阵乘法。用图表解释更直接。最后两个维度表示为 [16, 16],可以可视化如下:
这个矩阵,其中每行和每列在我们的例句的上下文中代表一个标记(单词)。矩阵乘法是衡量上下文中每个单词与所有其他单词之间的相似性。该值越高,它们越相似。
让我提出一个注意力得分的头:
1[ 0.2712, 0.5608, -0.4975, ..., -0.4172, -0.2944, 0.1899], 2[-0.0456, 0.3352, -0.2611, ..., 0.0419, 1.0149, 0.2020], 3[-0.0627, 0.1498, -0.3736, ..., -0.3537, 0.6299, 0.3374], 4 ..., ..., ..., ..., ..., ..., ..., 5 ..., ..., ..., ..., ..., ..., ..., 6[-0.4166, -0.3364, -0.0458, ..., -0.2498, -0.1401, -0.0726], 7[ 0.4109, 1.3533, -0.9120, ..., 0.7061, -0.0945, 0.2296], 8[-0.0602, 0.2428, -0.3014, ..., -0.0209, -0.6606, -0.3170] 9 10[16 rows x 16 columns]
这个 16 x 16 矩阵中的数字代表我们的例句 “ ” 的注意力分数。. By mastering the art of identifying underlying motivations and desires, we equip ourselves with
更容易看作一个情节:
横轴代表 Q 的头之一,纵轴表示 K 的头之一,彩色方块表示上下文中每个令牌和彼此令牌之间的相似性分数。颜色越深,相似度越高。
当然,上面显示的相似之处现在没有多大意义,因为这些只是来自随机分配的值。但是经过训练,相似性分数将是有意义的。
好了,现在让我们把批次维度 batch_size 带回 Q*K 注意力分数。最终结果的形状为 [batch_size, num_heads, context_length, head_size],即 [4, 4, 16, 16]。
这是当前步骤的 Q*K 注意力分数。
4.4 Scale
量表部分很简单,我们只需要将 Q*K^T 注意力分数除以 K 维度的平方根即可。
在这里,我们的 K 维数等于 Q 的维数,d_model除以 num_heads:64/4 = 16。
然后我们取 16 的平方根,即 4。并将 Q*K^T 注意力得分除以 4。
这样做的原因是为了防止 Q*K^T 注意力分数过大,这可能会导致 softmax 函数饱和,进而导致梯度消失。
4.5 Mask
在仅解码器转换器模型中,掩蔽的自我注意力本质上充当序列填充。
解码器只能查看以前的字符,而不能查看未来的字符。因此,未来的字符被屏蔽并用于计算注意力权重。
如果我们再次可视化情节,这很容易理解:
空格表示 0 分,被屏蔽了
多头注意力层中屏蔽的要点是防止解码器“看到未来”。在我们的例句中,解码器只允许看到当前单词和它之前的所有单词。
4.6 Softmax
softmax 步骤将数字更改为一种特殊的列表,其中整个列表加起来为 1。它增加了高数字并减少了低数字,从而创造了明确的选择。
简而言之,softmax 函数用于将线性层的输出转换为概率分布。
在现代深度学习框架(如 PyTorch)中,softmax 函数是一个内置函数,使用起来非常简单:
这行代码会将 softmax 应用于我们在上一步中计算的所有注意力分数,并产生介于 0 和 1 之间的概率分布。
让我们也提出应用softmax后同一头的注意力分数:
现在,所有概率分数均为正数,加起来为 1。
4.7 Calculate V Attention
最后一步是将 softmax 输出乘以 V 矩阵。
请记住,我们的 V 矩阵还将其拆分为多个头,形状为 (batch_size, num_heads, context_length, head_size) = [4, 4, 16, 16]。
而上一个 softmax 步骤的输出为 (batch_size, num_heads, context_length, head_size) = [4, 4, 16, 16]。
在这里,我们对两个矩阵的最后两个维度执行另一个矩阵乘法。
softmax_output * V = [4, 4, 16, 16] * [4, 4, 16, 16] = [4, 4, 16, 16]
结果的形状为 [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16]。
我们称此结果为 A。
4.8 Concatenate and Output
我们多头注意力的最后一步是将所有头连接在一起,并将它们穿过线性层。
串联的理想是将来自所有头部的信息组合在一起。因此,我们需要将 A 矩阵从 [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16] 重塑为 [batch_size, context_length, num_heads, head_size] = [4, 16, 4, 16]。原因是我们需要将最后两个维度放在一起,因此可以很容易地将它们(通过矩阵乘法)组合回大小。num_heads``head_size``d_model = 64
这可以通过 PyTorch 的内置函数轻松完成:
1A = A.transpose(1, 2) # [4, 16, 4, 16] [batch_size, context_length, num_heads, head_size]
接下来,我们需要将最后两个维度 [num_heads, head_size] = [4, 16] 组合到 [d_model] = [64]。
1A = A.reshape(batch_size, -1, d_model) # [4, 16, 64] [batch_size, context_length, d_model]
正如你所看到的,经过一系列的计算,我们的结果矩阵 A 现在回到了与我们的输入嵌入矩阵 X 相同的形状,即 [batch_size, context_length, d_model] = [4, 16, 64]。由于此输出结果将作为输入传递到下一层,因此必须保持输入和输出相同的形状。
但在将其传递到下一层之前,我们需要对它执行另一个线性变换。这是通过在串联矩阵 A 和 Wo 之间执行另一个矩阵乘法来完成的。
这个 Wo 被随机分配了形状 [d_model, d_model],并将在训练期间更新。
输出 = A* Wo = [4, 16, 64] * [64, 64] = [4, 16, 64]
线性层的输出是单头注意力的输出,表示为输出。
祝贺!现在我们已经完成了蒙面的多头注意力部分!让我们开始变压器块的其余部分。这些都很简单,所以我会快速浏览它们。
5 Residual Connection and Layer Normalization
残差连接(有时称为跳过连接)是允许原始输入 X 绕过一个或多个层的连接。
这只是原始输入 X 和多头注意力层输出的总和。由于它们的形状相同,因此将它们相加很简单。
output = output + X
残差连接后,该过程进入层归一化。LayerNorm 是一种对网络中每一层的输出进行规范化的技术。这是通过减去平均值并除以图层输出的标准差来完成的。此技术用于防止层的输出变得太大或太小,这可能导致网络变得不稳定。
残差连接和层归一化在“Attention is All You Need”的原始论文中表示。Add & Norm
6 Feed-Forward Network
一旦我们有了归一化的注意力权重(概率分数),它将通过一个位置前馈网络进行处理。
前馈网络 (FFN) 由两个线性层组成,它们之间具有 ReLU 激活函数。让我们看看 python 代码是如何实现的:
1# Define Feed Forward Network 2output = nn.Linear(d_model, d_model * 4)(output) 3output = nn.ReLU()(output) 4output = nn.Linear(d_model * 4, d_model)(output)
将 ChatGPT 解释为上述代码:
-
输出 = nn。Linear(d_model, d_model * 4)(输出):这将对传入数据应用线性变换,即 y = xA^T + b。输入和输出大小分别为 d_model 和 d_model * 4。此转换增加了输入数据的维度。
-
输出 = nn。ReLU()(输出):这在元素上应用整流线性单元 (ReLU) 函数。它被用作激活函数,将非线性引入模型,使其能够学习更复杂的模式。
-
输出 = nn。Linear(d_model * 4, d_model)(输出):这将应用另一个线性变换,将维数降低到d_model。这种“先扩张后收缩”是神经网络中的常见模式。
作为机器学习或LLM的新手,像你和我一样,可能会被这些解释所迷惑。当我第一次遇到这些术语时,我得到了完全相同的感觉。
但不用担心,我们可以这样理解:这个前馈网络只是一个标准的神经网络模块,它的输入和输出都是注意力分数。其目的是将注意力分数的维度从 64 扩展到 256,这使得信息更加精细,并使模型能够学习更复杂的知识结构。然后,它将尺寸压缩回 64,使其适用于后续计算。
7 Step 7 : Repeat Step 4 to 6
凉!我们已经完成了第一个变压器模块。现在,我们需要对我们想要的其余变压器块重复相同的过程。
在头部方面,我引用了 HuggingChat 的 AI 回应:
GPT-2 在其最大配置 (GPT-2-XL) 中使用 48 个转换器块,而较小的配置具有较少的转换器块(GPT-2-Large 为 36 个,GPT-2-Medium 为 24 个,GPT-2-Small 为 12 个)。每个变压器模块都包含一个多头自注意力机制,然后是按位置的前馈网络。这些转换器模块可帮助模型捕获长程依赖关系并生成连贯的文本。
通过具有多个模块,输出被训练并作为输入 X 传递到下一个模块,因此在迭代后,模型可以学习输入序列中单词之间更复杂的模式和关系。
8 Output Probabilities
在推理过程中,您希望从模型中获取下一个预测的标记,但到目前为止,我们得到的实际上是词汇表中所有标记的概率分布。你还记得上面的例子中我们的词汇量是 3,771 吗?因此,为了选择一个最高概率标记,我们将形成一个矩阵,其模型维度的大小 d_model = 64 乘以我们的 vocab_size = 3,771。这一步在训练上与在推理上没有区别。
1# Apply the final linear layer to get the logits 2logits = nn.Linear(d_model, vocab_size)(output)
我们将这个线性层之后的输出称为 logits。logits 是形状为 [batch_size, context_length, vocab_size] = [4, 16, 3771] 的矩阵。
然后使用最终的softmax函数将线性层的logits转换为概率分布。
1logits = torch.softmax(logits, dim=-1)
注意:在训练过程中,我们不需要在这里应用softmax函数,而是使用nn。CrossEntropy 函数,因为它内置了 softmax 行为。
我们如何查看形状 [4, 16, 3771] 的结果对数?实际上,经过所有计算,这是一个非常简单的想法:
我们有 4 个批处理管道,每个管道包含该输入序列中的所有 16 个单词,每个单词映射到词汇表中其他每个单词的概率。
如果模型在训练中,我们更新这些概率参数,如果模型在推理中,我们只需选择概率最高的一个。那么一切都有意义了。
总结
Transformer 架构的复杂性可能具有挑战性。如果想要深入了解,还需要结合实际代码多做尝试,我会在接下来的时间里,结合代码来说明Transformer 架构。